{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils import data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Dataset:\n",
    "    def __init__(self, **datas):\n",
    "        self.test = datas.get('test', False)\n",
    "        \n",
    "        self.title = datas['title']\n",
    "        self.desc = datas['desc']\n",
    "\n",
    "        # train为True，返回对应的label\n",
    "        if self.test is False:\n",
    "            self.n_classes = datas['class_num']\n",
    "            self.label = datas['label']\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        title = torch.from_numpy(self.title[idx])\n",
    "        desc = torch.from_numpy(self.desc[idx])\n",
    "        if self.test is False:\n",
    "            label = torch.zeros(self.n_classes).scatter_(0, torch.from_numpy(self.label[idx]).long(), 1)\n",
    "            return torch.cat((title, desc)), label\n",
    "        return torch.cat((title, desc))\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.title.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "word_embed_mat = np.load('../../data_preprocess/embed/word_embed_mat.npy')\n",
    "train_title_word = np.load('../../data_preprocess/train/train_title_word_indices.npy')\n",
    "train_desc_word = np.load('../../data_preprocess/train/train_desc_word_indices.npy')\n",
    "train_label = np.load('../../data_preprocess/train/train_label_indices.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_dataset = Dataset(title=train_title_word, desc=train_desc_word, label=train_label, class_num=1999)\n",
    "train_loader = data.DataLoader(train_dataset, shuffle=True, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 71]) torch.Size([64, 1999])\n"
     ]
    }
   ],
   "source": [
    "for i, datas in enumerate(train_loader):\n",
    "    print datas[0].size(), datas[1].size()\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "val_title_word = np.load('../../data_preprocess/val/val_title_word_indices.npy')\n",
    "val_desc_word = np.load('../../data_preprocess/val/val_desc_word_indices.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "val_dataset = Dataset(test=True, title=val_title_word, desc=val_desc_word)\n",
    "val_loader = data.DataLoader(val_dataset, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 71])\n"
     ]
    }
   ],
   "source": [
    "for i, datas in enumerate(val_loader):\n",
    "    print datas.size()\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_title_word_all = np.load('../../data_preprocess/train_all/train_title_word_indices_all.npy')\n",
    "train_desc_word_all = np.load('../../data_preprocess/train_all/train_desc_word_indices_all.npy')\n",
    "train_label_all = np.load('../../data_preprocess/train_all/train_label_indices_all.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_dataset_all = Dataset(title=train_title_word_all, desc=train_desc_word_all, label=train_label_all, class_num=1999)\n",
    "train_loader_all = data.DataLoader(train_dataset_all, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 71]) torch.Size([64, 1999])\n"
     ]
    }
   ],
   "source": [
    "for i, datas in enumerate(train_loader_all):\n",
    "    print datas[0].size(), datas[1].size()\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_title_word = np.load('../../data_preprocess/test/test_title_word_indices.npy')\n",
    "test_desc_word = np.load('../../data_preprocess/test/test_desc_word_indices.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_dataset = Dataset(test=True, title=test_title_word, desc=test_desc_word)\n",
    "test_loader = data.DataLoader(test_dataset, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 71])\n"
     ]
    }
   ],
   "source": [
    "for i, datas in enumerate(test_loader):\n",
    "    print datas.size()\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Stack_Dataset:\n",
    "    def __init__(self, **datas):\n",
    "        self.test = datas.get('test', False)\n",
    "#         self.resmat = [torch.load(ii) for ii in datas['resmat']]\n",
    "        self.resmat = datas['resmat']\n",
    "        self.stack_num = len(self.resmat)\n",
    "        # train=True, return label\n",
    "        if self.test is False:\n",
    "            self.label = datas['label']\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        res = tuple([ii[idx] for ii in self.resmat])\n",
    "        if self.test is False:\n",
    "            return res + (self.label[idx],)\n",
    "        return res\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.resmat[0].size(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_dir = '/home/dyj/'\n",
    "resmat = [result_dir+'RNN_2017-07-27#10:48:05_res.pt', result_dir+'TextCNN_2017-07-27#10:15:20_res.pt']\n",
    "label = result_dir+'label.pt'\n",
    "resmat = [torch.load(ii) for ii in resmat]\n",
    "label = torch.load(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "stack_train_dataset = Stack_Dataset(resmat=resmat, label=label)\n",
    "stack_train_loader = data.DataLoader(stack_train_dataset, shuffle=True, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "string index out of range",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0mTraceback (most recent call last)",
      "\u001b[0;32m<ipython-input-14-f3a3707087ea>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m     \u001b[0mresmat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0;32mprint\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresmat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresmat\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.pyc\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    188\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m             \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_indices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    191\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    192\u001b[0m                 \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-11-733a4071e01c>\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mii\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mii\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresmat\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mFalse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mIndexError\u001b[0m: string index out of range"
     ]
    }
   ],
   "source": [
    "for i, batch in enumerate(stack_train_loader, 0):\n",
    "    resmat, label = batch\n",
    "    print len(resmat), resmat[0].size(), label.size()\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
